#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple, Union

import torch
import torch.nn as nn

from .rasterize_points import rasterize_points


class PointFragments(NamedTuple):
    """
    Named tuple bundling the outputs of point rasterization.

    The fields hold, in order, the three values returned by
    `rasterize_points`; refer to that function for their exact contents.
    """

    # First value returned by `rasterize_points`.
    idx: torch.Tensor
    # Second value returned by `rasterize_points`.
    zbuf: torch.Tensor
    # Third value returned by `rasterize_points`. `PointsRasterizer.forward`
    # binds it to a local called `dists2` before storing it here.
    dists: torch.Tensor


@dataclass
class PointsRasterizationSettings:
    """
    Container for the point rasterization parameters and their defaults.

    Members:
        image_size: Output size in pixels. A single int is used for both
            height and width; a tuple is read as (height, width).
        radius: Radius, in NDC units, of the disk rasterized for each point.
            A float applies the same radius to every point; a torch.Tensor
            of shape (N, P) supplies one radius per point in the batch.
        points_per_pixel: (int) How many points are tracked for each pixel.
            The points_per_pixel points nearest along the z-axis are returned.
        bin_size: Bin size for coarse-to-fine rasterization. bin_size=0
            selects naive rasterization, while bin_size=None tries to pick a
            value heuristically from the shape of the input. The output
            should be unaffected; the speed of the forward pass may not be.
        max_points_per_bin: Upper bound on the number of points allowed in
            each bin. Only relevant for coarse-to-fine rasterization
            (bin_size != 0). The output values should be unaffected; the
            memory used by the forward pass may not be.
            max_points_per_bin=None tries to pick a value with a heuristic.
    """

    image_size: Union[int, Tuple[int, int]] = 256
    radius: Union[float, torch.Tensor] = 0.01
    points_per_pixel: int = 8
    bin_size: Optional[int] = None
    max_points_per_bin: Optional[int] = None


class PointsRasterizer(nn.Module):
    """
    This class implements methods for rasterizing a batch of pointclouds.
    """

    def __init__(self, cameras=None, raster_settings=None) -> None:
        """
        Args:
            cameras: A cameras object. `transform` calls its
                `get_world_to_view_transform`, `get_projection_transform`
                and `get_ndc_camera_transform` methods to apply the
                world-to-view and view-to-ndc transformations. May be left
                as None here and supplied in the forward pass instead.
            raster_settings: The parameters for rasterization, e.g. a
                `PointsRasterizationSettings` instance. When None, a
                default-constructed `PointsRasterizationSettings` is used.

        All these initial settings can be overridden by passing keyword
        arguments to the forward function.
        """
        super().__init__()
        if raster_settings is None:
            raster_settings = PointsRasterizationSettings()

        self.cameras = cameras
        self.raster_settings = raster_settings

    def transform(self, point_clouds, **kwargs) -> torch.Tensor:
        """
        Project the padded points of `point_clouds` with the cameras.

        Args:
            point_clouds: a set of point clouds. It must provide the
                `points_padded()` and `update_padded()` methods.
            **kwargs: `cameras` overrides the cameras given at
                initialization and `eps` is passed to the `transform_points`
                calls. The full set of keyword arguments is also forwarded
                to the camera methods that build the transforms.

        Returns:
            points_proj: the result of `point_clouds.update_padded(...)`
            called with the projected points, whose x and y are in NDC
            space and whose z is kept in view space.
            NOTE(review): the `torch.Tensor` return annotation looks
            inaccurate, since the value comes from `update_padded` rather
            than being a bare tensor; confirm and correct.

        Raises:
            ValueError: if no cameras were given at initialization or in
                `kwargs`.

        NOTE: keeping this as a separate function for readability but it could
        be moved into forward.
        """
        cameras = kwargs.get("cameras", self.cameras)
        if cameras is None:
            # Adjacent string literals are used instead of a backslash
            # continuation so that source indentation does not end up
            # inside the message.
            msg = (
                "Cameras must be specified either at initialization "
                "or in the forward pass of PointsRasterizer"
            )
            raise ValueError(msg)

        pts_world = point_clouds.points_padded()
        # NOTE: Retaining view space z coordinate for now.
        # TODO: Remove this line when the convention for the z coordinate in
        # the rasterizer is decided. i.e. retain z in view space or transform
        # to a different range.
        eps = kwargs.get("eps", None)
        pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
            pts_world, eps=eps
        )
        # view to NDC transform
        to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
        projection_transform = cameras.get_projection_transform(**kwargs).compose(
            to_ndc_transform
        )
        pts_ndc = projection_transform.transform_points(pts_view, eps=eps)

        # Overwrite the projected z with the view-space z (see NOTE above).
        pts_ndc[..., 2] = pts_view[..., 2]
        point_clouds = point_clouds.update_padded(pts_ndc)
        return point_clouds

    def to(self, device):
        """
        Move the stored cameras to `device` and return this rasterizer.

        Cameras are optional at construction time, so nothing is moved when
        `self.cameras` is None.
        """
        # Manually move to device cameras as it is not a subclass of nn.Module
        if self.cameras is not None:
            self.cameras = self.cameras.to(device)
        return self

    def forward(self, point_clouds, **kwargs) -> PointFragments:
        """
        Args:
            point_clouds: a set of point clouds with coordinates in world space.
            **kwargs: `raster_settings` overrides the settings given at
                initialization. All keyword arguments are also passed on to
                `transform`.
        Returns:
            PointFragments: Rasterization outputs as a named tuple.
        """
        points_proj = self.transform(point_clouds, **kwargs)
        raster_settings = kwargs.get("raster_settings", self.raster_settings)
        idx, zbuf, dists2 = rasterize_points(
            points_proj,
            image_size=raster_settings.image_size,
            radius=raster_settings.radius,
            points_per_pixel=raster_settings.points_per_pixel,
            bin_size=raster_settings.bin_size,
            max_points_per_bin=raster_settings.max_points_per_bin,
        )
        return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)
